//! The client cache, which stores a read-only replica of a subset of a remote database.
//!
//! This module is internal, and may incompatibly change without warning.

use crate::callbacks::CallbackId;
use crate::db_connection::{PendingMutation, SharedCell};
use crate::spacetime_module::{InModule, SpacetimeModule, TableUpdate, WithBsatn};
use anymap::{any::Any, Map};
use bytes::Bytes;
use core::any::type_name;
use core::hash::Hash;
use futures_channel::mpsc;
use spacetimedb_data_structures::map::{DefaultHashBuilder, Entry, HashCollectionExt, HashMap};
use std::marker::PhantomData;
use std::sync::Arc;

/// A local mirror of the subscribed rows of one table in the database.
pub struct TableCache<Row> {
    /// A map of row-bytes to rows, each stored together with its reference count.
    ///
    /// The keys are BSATN-serialized representations of the values.
    /// Storing both the bytes and the deserialized rows allows us to have a `HashMap`
    /// even when `Row` is not `Hash + Eq`, e.g. for row types which contain floats.
    /// We also suspect that hashing and equality comparisons for byte arrays
    /// are more efficient than for domain types,
    /// as they can be implemented directly via SIMD without skipping padding
    /// or branching on enum variants.
    ///
    /// A row is resident in this map exactly while its `ref_count` (see `RowEntry`) is non-zero:
    /// it is added on its first insert and removed when its count drops back to 0.
    pub(crate) entries: HashMap<Bytes, RowEntry<Row>>,

    /// Each of the unique indices on this table, keyed by index name.
    ///
    /// The values of this map will all be instances of `UniqueIndexImpl`.
    /// The boxing and `dyn` dispatch is necessary to erase the type of the indexed column.
    ///
    /// Entries are added to this map during [`crate::DbConnectionBuilder::build`],
    /// via a `register_table` function autogenerated for each table.
    /// Indices may only be added while `entries` is still empty,
    /// and each index keeps its own copy of every resident row.
    pub(crate) unique_indices: HashMap<&'static str, Box<dyn UniqueIndexDyn<Row = Row>>>,
}

/// A typed row value stored in the table cache, together with its reference count.
pub(crate) struct RowEntry<Row> {
    /// The typed row value, interpreted from its raw BSATN bytes.
    row: Row,

    /// The reference count of `row`,
    /// i.e. the number of queries on this table in which `row` is currently live.
    ///
    /// The count may be more than 1 if there are overlapping queries in a subscription set.
    /// For example, assuming you are using SpacetimeDB to run a library, given:
    /// ```sql,ignore
    /// to_cull = SELECT * FROM book WHERE book.condition < OKAY;
    /// fantasy = SELECT * FROM book WHERE book.genre = "fantasy";
    /// ```
    /// The query `to_cull` may intersect with `fantasy`, leaving us with e.g. `ref_count = 2`
    /// for a fantasy book in poor condition.
    ///
    /// Each insert of `row` increments the count and each delete decrements it;
    /// the entry is evicted from the cache once the count reaches 0.
    ///
    /// This is stored as a `u32` as exceeding `u32::MAX` won't happen.
    /// We could probably get away with `u16`, but playing it a bit safe here.
    ref_count: u32,
}

// Can't derive this because the `Row` generic messes us up.
impl<Row> Default for TableCache<Row> {
    fn default() -> Self {
        Self {
            entries: Default::default(),
            unique_indices: Default::default(),
        }
    }
}

/// A set of row events (inserts or deletes) for one table,
/// mapping each row's BSATN bytes to the typed row.
///
/// Keying on the bytes means each distinct row appears at most once.
type RowEventMap<'r, Row> = HashMap<&'r [u8], &'r Row>;

/// The diff result of applying [`TableUpdate`] to a [`TableCache`].
///
/// When first constructed by [`ClientCache::apply_diff_to_table`],
/// the list of updates (`update_deletes.zip(update_inserts)`) is empty.
/// It must be populated by calling [`TableAppliedDiff::with_updates_by_pk`]
/// with a projection from rows to the primary key of the table.
///
/// The set `deletes` is disjoint from `update_deletes`,
/// and the set `inserts` is disjoint from `update_inserts`.
/// When the latter are populated, their elements are *moved* out of the former.
pub struct TableAppliedDiff<'r, Row> {
    /// The unique set of semantic deletes ("evictions") from the client cache.
    deletes: RowEventMap<'r, Row>,
    /// The unique set of semantic inserts into the client cache.
    inserts: RowEventMap<'r, Row>,
    /// The delete part of the unique set of semantic updates from the client cache.
    /// For every element in this list there is a corresponding one in `update_inserts`.
    update_deletes: Vec<&'r Row>,
    /// The insert part of the unique set of semantic updates from the client cache.
    /// For every element in this list there is a corresponding one in `update_deletes`.
    update_inserts: Vec<&'r Row>,
}

impl<Row> Default for TableAppliedDiff<'_, Row> {
    fn default() -> Self {
        Self {
            deletes: <_>::default(),
            inserts: <_>::default(),
            update_deletes: <_>::default(),
            update_inserts: <_>::default(),
        }
    }
}

impl<'r, Row> TableAppliedDiff<'r, Row> {
    /// Returns the applied diff restructured so that each delete/insert pair
    /// sharing a primary key, as computed by `derive_pk`, is reported as an update.
    pub fn with_updates_by_pk<Pk: Eq + Hash>(mut self, derive_pk: impl Fn(&Row) -> &Pk) -> Self {
        self.derive_updates(derive_pk);
        self
    }

    /// Pairs up deletes and inserts which share a primary key, as computed by `derive_pk`.
    /// Each such pair is moved out of `deletes`/`inserts` and appended to
    /// `update_deletes`/`update_inserts`.
    ///
    /// Each delete is paired with at most one insert, so `update_deletes` and `update_inserts`
    /// always grow by the same amount and stay index-aligned.
    /// Pairs found by earlier calls are kept.
    fn derive_updates<Pk: Eq + Hash>(&mut self, derive_pk: impl Fn(&Row) -> &Pk) {
        // An update needs both a delete and an insert; without either there is nothing to pair.
        if self.deletes.is_empty() || self.inserts.is_empty() {
            return;
        }

        // Compute the PK -> (bsatn, row) map for deletes.
        let mut delete_pks =
            <HashMap<_, _, DefaultHashBuilder> as HashCollectionExt>::with_capacity(self.deletes.len());
        for (&bsatn, &row) in self.deletes.iter() {
            let pk = derive_pk(row);
            delete_pks.insert(pk, (bsatn, row));
        }

        // Move every insert whose PK matches a delete out of `inserts`,
        // and move the matching delete out of `deletes`.
        let matched_inserts = self
            .inserts
            .extract_if(|_, ins_row| {
                let pk = derive_pk(ins_row);
                // `remove` rather than `get`, so a delete can never be paired with two inserts,
                // which would desynchronise `update_deletes` from `update_inserts`.
                let Some((del_bsatn, del_row)) = delete_pks.remove(pk) else {
                    return false;
                };
                self.update_deletes.push(del_row);
                let _deleted = self.deletes.remove(del_bsatn);
                debug_assert!(_deleted.is_some());
                true
            })
            .map(|(_, ins_row)| ins_row);
        // `extend` rather than assignment, so pairs from a previous call are not discarded
        // while their deletes remain in `update_deletes`.
        self.update_inserts.extend(matched_inserts);
    }

    /// Returns whether the applied diff is empty.
    pub(super) fn is_empty(&self) -> bool {
        self.deletes.is_empty()
            && self.inserts.is_empty()
            && self.update_deletes.is_empty()
            && self.update_inserts.is_empty()
    }

    /// Returns the deleted rows in this diff which are not part of an update.
    pub(super) fn deletes(&self) -> impl '_ + Iterator<Item = &'r Row> {
        self.deletes.values().copied()
    }

    /// Returns the inserted rows in this diff which are not part of an update.
    pub(super) fn inserts(&self) -> impl '_ + Iterator<Item = &'r Row> {
        self.inserts.values().copied()
    }

    /// Returns the updated rows in this diff as `(old, new)` pairs.
    /// This will be empty if [`Self::derive_updates`] never ran or never paired any rows.
    pub(super) fn updates(&self) -> impl '_ + Iterator<Item = (&'r Row, &'r Row)> {
        self.update_deletes
            .iter()
            .copied()
            .zip(self.update_inserts.iter().copied())
    }
}

impl<Row: Clone + Send + Sync + 'static> TableCache<Row> {
    /// Apply a single `delete` to the cache, decrementing the deleted row's `ref_count`.
    ///
    /// When the `ref_count` drops to 0, the row is evicted from `self.entries`, and then:
    /// - if the row was inserted earlier in the same diff (its `ref_count` went `0 -> 1`,
    ///   so it is recorded in `inserts`), that insert event is discarded
    ///   and no delete event is recorded, because the row is absent both before and after the diff;
    /// - otherwise, a delete event for the row is recorded in `deletes`.
    ///
    /// Keeping the two event maps consistent this way also keeps the unique indices consistent.
    /// `apply_diff` removes delete events from the indices,
    /// so only rows that were previously added to them may appear there.
    ///
    /// # Panics
    ///
    /// Panics if `delete` does not correspond to a row resident in the cache.
    fn handle_delete<'r>(
        &mut self,
        inserts: &mut RowEventMap<'_, Row>,
        deletes: &mut RowEventMap<'r, Row>,
        delete: &'r WithBsatn<Row>,
    ) {
        // Extract the entry and decrement the `ref_count`.
        // Only evict the row once `ref_count = 0`.
        let Entry::Occupied(mut entry) = self.entries.entry(delete.bsatn.clone()) else {
            // We're guaranteed to never hit this as long as we apply inserts before deletes.
            unreachable!("a delete update should correspond to an existing row in the table cache");
        };
        let ref_count = &mut entry.get_mut().ref_count;
        *ref_count -= 1;
        if *ref_count != 0 {
            return;
        }
        entry.remove();

        // While one might think the host never sends us a delete-insert pair for the same row `r0`,
        // it actually may, given the right joins.
        //
        // For example, consider three tables `r`, `s`, and `t`.
        // Let's suppose the client has subscribed to `r ⋉ s` and `r ⋉ t`:
        // ```sql
        // SELECT r.* FROM r JOIN s ON r.id = s.id;
        // SELECT r.* FROM r JOIN t ON r.id = t.id;
        // ```
        //
        // A transaction then:
        // - deletes a row `t0` which results in `delete r0` being sent.
        // - inserts a row `s0` which results in `insert r0` being sent.
        //
        // That is, we end up with `[delete r0, insert r0]`.
        //
        // Inserts are applied before deletes, so if `r0` was already resident,
        // its `ref_count` never drops to 0 here and no events are produced.
        // If it was not resident, the insert took its `ref_count` from 0 to 1
        // and recorded an insert event, and this delete brings it back to 0.
        // The two cancel out, so we drop the insert event and record no delete event.
        // Recording a delete here would fire a delete callback for a row never reported as inserted,
        // and would make `apply_diff` remove from the unique indices a row it never added to them.
        if inserts.remove(&*delete.bsatn).is_none() {
            deletes.insert(&delete.bsatn, &delete.row);
        }
    }

    /// Apply a single `insert` to the cache, incrementing the inserted row's `ref_count`.
    ///
    /// If the row was not previously resident, it is added to `self.entries`
    /// and an insert event is recorded in `inserts`.
    fn handle_insert<'r>(&mut self, inserts: &mut RowEventMap<'r, Row>, insert: &'r WithBsatn<Row>) {
        let entry = self.entries.entry(insert.bsatn.clone());
        let entry = entry.or_insert_with(|| {
            // First time inserting this row, so let's add an insertion event.
            inserts.insert(&insert.bsatn, &insert.row);
            RowEntry {
                row: insert.row.clone(),
                ref_count: 0,
            }
        });
        entry.ref_count += 1;
    }

    /// Apply all the deletes and inserts recorded in `diff`.
    /// Return a [`TableAppliedDiff`] which encodes the actual changes to rows present or deleted
    /// after taking into account rows' refcounts.
    ///
    /// The resulting [`TableAppliedDiff`] will contain only deletes and inserts, which are disjoint.
    /// Its `update_deletes` and `update_inserts` fields will be empty.
    /// The caller should use [`TableAppliedDiff::with_updates_by_pk`] to merge delete/insert pairs
    /// and populate the `update_*` fields.
    ///
    /// Every unique index is updated to match: evicted rows are removed and newly resident rows added.
    fn apply_diff<'r>(&mut self, diff: &'r TableUpdate<Row>) -> TableAppliedDiff<'r, Row> {
        // Apply all inserts and collect all `ref_count: 0 -> 1` events.
        // Inserts must be applied before deletes to avoid the panic in `handle_delete`
        // and to avoid duplicate index insertion errors.
        let mut insert_events = <_>::default();
        for insert in &diff.inserts {
            self.handle_insert(&mut insert_events, insert);
        }

        // Apply all deletes and collect all `ref_count -> 0` events.
        let mut delete_events = <_>::default();
        for delete in &diff.deletes {
            self.handle_delete(&mut insert_events, &mut delete_events, delete);
        }

        // Update indices.
        // We apply deletes first to make space for later insertions
        // and to avoid duplicates in any unique index.
        for index in self.unique_indices.values_mut() {
            for row in delete_events.values() {
                index.remove_row(row);
            }
            for &row in insert_events.values() {
                index.add_row(row.clone());
            }
        }

        TableAppliedDiff {
            deletes: delete_events,
            inserts: insert_events,
            update_deletes: Vec::new(),
            update_inserts: Vec::new(),
        }
    }

    /// Look up the row whose value in the column indexed by `unique_index_name` equals `key`.
    ///
    /// # Panics
    ///
    /// Panics if no unique index named `unique_index_name` is registered on this table,
    /// or if `key` is not of the indexed column's type.
    fn find_by_unique_index<'this>(
        &'this self,
        unique_index_name: &'static str,
        key: &'_ dyn std::any::Any,
    ) -> Option<&'this Row> {
        let index = self
            .unique_indices
            .get(unique_index_name)
            .unwrap_or_else(|| panic!("No such unique index: {unique_index_name}"));
        index.find_row(key)
    }

    /// Register a unique index named `unique_index_name` on the column projected by `get_unique_col`.
    ///
    /// Called by the codegen when initializing the client cache during [`crate::DbConnectionBuilder::build`].
    ///
    /// # Panics
    ///
    /// Panics if the table already contains rows,
    /// or if a unique index with the same name is already registered.
    pub fn add_unique_constraint<Col>(&mut self, unique_index_name: &'static str, get_unique_col: fn(&Row) -> &Col)
    where
        Col: Any + Clone + std::hash::Hash + Eq + Send + Sync + std::fmt::Debug + 'static,
    {
        assert!(self.entries.is_empty(), "Cannot add a unique constraint to a populated table; constraints should only be added during initialization, before subscribing to any rows.");
        if self
            .unique_indices
            .insert(
                unique_index_name,
                Box::new(UniqueIndexImpl {
                    get_unique_col,
                    rows: Default::default(),
                }),
            )
            .is_some()
        {
            panic!("Duplicate unique constraint name {unique_index_name}");
        }
    }
}

/// A local mirror of the subscribed subset of the database.
pub struct ClientCache<M: SpacetimeModule + ?Sized> {
    /// A type-keyed map whose values are of type `HashMap<&'static str, TableCache<Row>>`,
    /// one per row type `Row`.
    ///
    /// The inner maps are keyed on table names, since we may have multiple tables with the same row type.
    tables: Map<dyn Any + Send + Sync>,

    /// Ties this cache to its module type `M` without storing a value of that type.
    _module: PhantomData<M>,
}

impl<M: SpacetimeModule> Default for ClientCache<M> {
    fn default() -> Self {
        Self {
            tables: Map::new(),
            _module: PhantomData,
        }
    }
}

impl<M: SpacetimeModule> ClientCache<M> {
    /// Get a handle on the [`TableCache`] which stores rows of type `Row` for the table `table_name`.
    pub(crate) fn get_table<Row: InModule<Module = M> + Send + Sync + 'static>(
        &self,
        table_name: &'static str,
    ) -> Option<&TableCache<Row>> {
        self.tables
            .get::<HashMap<&'static str, TableCache<Row>>>()
            .and_then(|tables_of_row_type| tables_of_row_type.get(table_name))
    }

    /// Called internally when updating the client cache in response to WebSocket messages,
    /// and by the codegen when initializing the client cache during [`crate::DbConnectionBuilder::build`].
    pub fn get_or_make_table<Row: InModule<Module = M> + Send + Sync + 'static>(
        &mut self,
        table_name: &'static str,
    ) -> &mut TableCache<Row> {
        self.tables
            .entry::<HashMap<&'static str, TableCache<Row>>>()
            .or_insert_with(Default::default)
            .entry(table_name)
            .or_default()
    }

    /// Apply all the mutations in `diff`
    /// to the [`TableCache`] which stores rows of type `Row` for the table `table_name`.
    pub fn apply_diff_to_table<'r, Row: InModule<Module = M> + Clone + Send + Sync + 'static>(
        &mut self,
        table_name: &'static str,
        diff: &'r TableUpdate<Row>,
    ) -> TableAppliedDiff<'r, Row> {
        if diff.is_empty() {
            return <_>::default();
        }

        let table = self.get_or_make_table::<Row>(table_name);

        table.apply_diff(diff)
    }
}

/// Internal implementation of a generated `TableHandle` struct,
/// which mediates access to a table in the client cache.
///
/// `TableHandle`s don't actually hold a direct reference to the table they access,
/// as that would require both gnarly lifetimes and also a `MutexGuard` on the client cache.
/// Instead, they hold an `Arc<Mutex>` on the whole [`ClientCache`].
/// Every operation through the table handle acquires the lock only for the duration of the operation,
/// calling [`ClientCache::get_table`] and discarding its reference before returning.
pub struct TableHandle<Row: InModule> {
    /// The shared client cache containing the table this handle accesses.
    pub(crate) client_cache: SharedCell<ClientCache<Row::Module>>,
    /// Handle on the connection's `pending_mutations_send` channel,
    /// so we can send callback-related [`PendingMutation`] messages.
    pub(crate) pending_mutations: mpsc::UnboundedSender<PendingMutation<Row::Module>>,

    /// The name of the table.
    pub(crate) table_name: &'static str,
}

impl<Row: InModule> Clone for TableHandle<Row> {
    /// Produce another handle on the same table,
    /// sharing the same client cache and pending-mutations channel.
    fn clone(&self) -> Self {
        let Self {
            client_cache,
            pending_mutations,
            table_name,
        } = self;
        Self {
            client_cache: Arc::clone(client_cache),
            pending_mutations: pending_mutations.clone(),
            table_name: *table_name,
        }
    }
}

impl<Row: InModule + Send + Sync + Clone + 'static> TableHandle<Row> {
    /// Run `get` against the [`TableCache`] this handle refers to,
    /// holding the client cache's lock only for the duration of the call.
    ///
    /// Panics if locking the client cache fails, or if the cache has no table named
    /// `self.table_name` with row type `Row`.
    fn with_table_cache<Res>(&self, get: impl FnOnce(&TableCache<Row>) -> Res) -> Res {
        let guard = self.client_cache.lock().unwrap();
        let Some(table) = guard.get_table::<Row>(self.table_name) else {
            panic!("No such table: {}", self.table_name);
        };
        get(table)
    }

    /// Called by the autogenerated implementation of the [`crate::Table`] method of the same name.
    pub fn count(&self) -> u64 {
        let resident = self.with_table_cache(|table| table.entries.len());
        resident as u64
    }

    /// Called by the autogenerated implementation of the [`crate::Table`] method of the same name.
    ///
    /// Iterates over a snapshot of the rows, cloned while the lock was held.
    pub fn iter(&self) -> impl Iterator<Item = Row> {
        let snapshot = self.with_table_cache(|table| {
            table
                .entries
                .values()
                .map(|entry| entry.row.clone())
                .collect::<Vec<_>>()
        });
        snapshot.into_iter()
    }

    /// Send `mutation` on the connection's pending-mutations channel.
    ///
    /// Panics if the send fails.
    fn queue_mutation(&self, mutation: PendingMutation<Row::Module>) {
        let sent = self.pending_mutations.unbounded_send(mutation);
        sent.unwrap();
    }

    /// Called by the autogenerated implementation of the [`crate::Table`] method of the same name.
    pub fn on_insert(
        &self,
        mut callback: impl FnMut(&<Row::Module as SpacetimeModule>::EventContext, &Row) + Send + 'static,
    ) -> CallbackId {
        let id = CallbackId::get_next();
        let mutation = PendingMutation::AddInsertCallback {
            table: self.table_name,
            callback: Box::new(move |ctx, erased_row| {
                let typed_row = erased_row.downcast_ref::<Row>().unwrap();
                callback(ctx, typed_row);
            }),
            callback_id: id,
        };
        self.queue_mutation(mutation);
        id
    }

    /// Called by the autogenerated implementation of the [`crate::Table`] method of the same name.
    pub fn remove_on_insert(&self, callback: CallbackId) {
        let mutation = PendingMutation::RemoveInsertCallback {
            table: self.table_name,
            callback_id: callback,
        };
        self.queue_mutation(mutation);
    }

    /// Called by the autogenerated implementation of the [`crate::Table`] method of the same name.
    pub fn on_delete(
        &self,
        mut callback: impl FnMut(&<Row::Module as SpacetimeModule>::EventContext, &Row) + Send + 'static,
    ) -> CallbackId {
        let id = CallbackId::get_next();
        let mutation = PendingMutation::AddDeleteCallback {
            table: self.table_name,
            callback: Box::new(move |ctx, erased_row| {
                let typed_row = erased_row.downcast_ref::<Row>().unwrap();
                callback(ctx, typed_row);
            }),
            callback_id: id,
        };
        self.queue_mutation(mutation);
        id
    }

    /// Called by the autogenerated implementation of the [`crate::Table`] method of the same name.
    pub fn remove_on_delete(&self, callback: CallbackId) {
        let mutation = PendingMutation::RemoveDeleteCallback {
            table: self.table_name,
            callback_id: callback,
        };
        self.queue_mutation(mutation);
    }

    /// Called by the autogenerated implementation of the [`crate::TableWithPrimaryKey`] method of the same name.
    pub fn on_update(
        &self,
        mut callback: impl FnMut(&<Row::Module as SpacetimeModule>::EventContext, &Row, &Row) + Send + 'static,
    ) -> CallbackId {
        let id = CallbackId::get_next();
        let mutation = PendingMutation::AddUpdateCallback {
            table: self.table_name,
            callback: Box::new(move |ctx, erased_old, erased_new| {
                let typed_old = erased_old.downcast_ref::<Row>().unwrap();
                let typed_new = erased_new.downcast_ref::<Row>().unwrap();
                callback(ctx, typed_old, typed_new);
            }),
            callback_id: id,
        };
        self.queue_mutation(mutation);
        id
    }

    /// Called by the autogenerated implementation of the [`crate::TableWithPrimaryKey`] method of the same name.
    pub fn remove_on_update(&self, callback: CallbackId) {
        let mutation = PendingMutation::RemoveUpdateCallback {
            table: self.table_name,
            callback_id: callback,
        };
        self.queue_mutation(mutation);
    }

    /// Called by autogenerated unique index access methods.
    ///
    /// Returns a handle which looks up rows through the unique index named `constraint_name`.
    pub fn get_unique_constraint<Col>(&self, constraint_name: &'static str) -> UniqueConstraintHandle<Row, Col> {
        let table_handle = self.clone();
        UniqueConstraintHandle {
            table_handle,
            unique_index_name: constraint_name,
            _phantom: PhantomData,
        }
    }
}

/// A handle on a unique index of a table in the client cache,
/// allowing point queries on the indexed column.
///
/// Lookups are served by the `UniqueIndexImpl` registered on the table under `unique_index_name`,
/// which stores the table's rows in a hash map keyed on the indexed column.
///
/// Like [`TableHandle`], unique constraint handles don't hold a direct reference to their table
/// or to the index within it.
/// Instead, they hold a handle on the whole [`ClientCache`],
/// and acquire short-lived exclusive access to it during operations.
pub struct UniqueConstraintHandle<Row: InModule, Col> {
    /// Handle on the table this unique index belongs to.
    table_handle: TableHandle<Row>,
    /// The name under which the index was registered with `TableCache::add_unique_constraint`.
    unique_index_name: &'static str,
    /// Records the `Col` and `Row` type parameters without storing values of either type.
    _phantom: PhantomData<HashMap<Col, Row>>,
}

impl<Row, Col> UniqueConstraintHandle<Row, Col>
where
    Row: Clone + InModule + Send + Sync + 'static,
    Col: std::any::Any + Eq + std::hash::Hash + Clone + Send + Sync + std::fmt::Debug + 'static,
{
    /// Find the resident row whose indexed column equals `col_val`, returning a clone of it.
    ///
    /// Panics if the table or the unique index is missing from the client cache.
    pub fn find(&self, col_val: &Col) -> Option<Row> {
        let index_name = self.unique_index_name;
        self.table_handle.with_table_cache(|table| {
            let found = table.find_by_unique_index(index_name, col_val);
            found.cloned()
        })
    }
}

/// [`UniqueIndexImpl`], but with its `Col` type parameter erased.
///
/// Stored boxed in [`TableCache::unique_indices`] so that indices on columns of different types
/// can live in the same map.
pub trait UniqueIndexDyn: Send + Sync + 'static {
    /// The `Row` type parameter to [`UniqueIndexImpl`]; the type of rows in the indexed table.
    type Row: Clone + Send + Sync + 'static;

    /// Insert a new row into the index.
    ///
    /// Panics if an existing row has the same value in the indexed column.
    fn add_row(&mut self, row: Self::Row);
    /// Delete a row from the index.
    ///
    /// Panics if no resident row in the index has the same value in the indexed column.
    ///
    /// Does not check that the row removed from the index is exactly equal to `row`,
    /// only that they have matching values in the indexed column.
    fn remove_row(&mut self, row: &Self::Row);
    /// Look up the row with `key` as its value in the indexed column,
    /// if such a row is resident in the client cache.
    /// Returns `None` if no resident row has that value.
    ///
    /// Panics if `key` is not of the same type as the indexed column.
    fn find_row<'this>(&'this self, key: &'_ dyn std::any::Any) -> Option<&'this Self::Row>;
}

/// A unique index on a table with rows of type `Row`, indexing a column of type `Col`.
pub struct UniqueIndexImpl<Row, Col> {
    /// All the rows in the table, indexed by their unique column.
    ///
    /// Unpleasant hack: each unique index stores duplicates of the entire row,
    /// rather than, say, an index or reference into the containing table.
    /// This is because [`HashMap`] does not expose stable indices/references other than hashes.
    /// Its chief competitor which does, `IndexMap`, loses index/reference stability
    /// when removing elements, which would require complicated fixups.
    ///
    /// One could imagine storing an `IntMap` from opaque indices/references to rows,
    /// having the outer [`TableCache`] map row bytes to these integers,
    /// and having unique indices do the same.
    rows: HashMap<Col, Row>,
    /// Given a row, get the value of the indexed column.
    /// Used both to compute keys on insertion/removal and to report conflicting keys.
    get_unique_col: fn(&Row) -> &Col,
}

impl<Row, Col> UniqueIndexDyn for UniqueIndexImpl<Row, Col>
where
    Row: Clone + Send + Sync + 'static,
    Col: Any + Clone + std::hash::Hash + Eq + Send + Sync + std::fmt::Debug + 'static,
{
    type Row = Row;

    fn add_row(&mut self, row: Self::Row) {
        let key = (self.get_unique_col)(&row).clone();
        let Some(displaced) = self.rows.insert(key, row) else {
            return;
        };
        // `displaced` shares its key with the row just inserted; report that key.
        panic!(
            "Duplicated entry in unique index at key {:?}, for type {}",
            (self.get_unique_col)(&displaced),
            type_name::<Row>()
        );
    }

    fn remove_row(&mut self, row: &Self::Row) {
        let key = (self.get_unique_col)(row);
        let removed = self.rows.remove(key);
        removed.expect("UniqueIndexDyn::remove_row for non-present row");
    }

    fn find_row<'this>(&'this self, key: &'_ dyn std::any::Any) -> Option<&'this Self::Row> {
        let typed_key: &Col = key
            .downcast_ref::<Col>()
            .expect("UniqueIndexDyn::find_row with key of incorrect type");
        self.rows.get(typed_key)
    }
}
